package com.demo.utils;

import lombok.extern.slf4j.Slf4j;
import org.springframework.core.io.buffer.DataBufferUtils;
import reactor.core.publisher.Flux;

import java.nio.charset.StandardCharsets;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * @author gy
 * @date 2025/9/23
 */
@Slf4j
public class FluxUtils {

    // 添加统计计数器
    private final AtomicInteger totalProcessedBlocks = new AtomicInteger(0);
    private final AtomicInteger corruptedBlocks = new AtomicInteger(0);
    private final AtomicInteger fixedBlocks = new AtomicInteger(0);

    /**
     * 创建支持UTF-8字符缓冲的数据流处理器
     * 解决大数据包被拆分导致UTF-8字符截断的问题
     */
    public Flux<String> createUtf8BufferedFlux(Flux<org.springframework.core.io.buffer.DataBuffer> dataBufferFlux) {
        return dataBufferFlux
                .scan(new BufferState(), (state, dataBuffer) -> {
                    try {
                        // 读取当前数据包的字节
                        byte[] currentBytes = new byte[dataBuffer.readableByteCount()];
                        dataBuffer.read(currentBytes);
                        DataBufferUtils.release(dataBuffer);

                        // 合并累积的字节和当前字节
                        byte[] combined = new byte[state.accumulatedBytes.length + currentBytes.length];
                        System.arraycopy(state.accumulatedBytes, 0, combined, 0, state.accumulatedBytes.length);
                        System.arraycopy(currentBytes, 0, combined, state.accumulatedBytes.length, currentBytes.length);

                        // 找到最后一个完整的UTF-8字符边界
                        int validEnd = findLastValidUtf8Boundary(combined);

                        if (validEnd > 0) {
                            // 转换完整的字符部分
                            byte[] validBytes = new byte[validEnd];
                            System.arraycopy(combined, 0, validBytes, 0, validEnd);
                            String newContent = new String(validBytes, StandardCharsets.UTF_8);

                            // 检查是否有完整的SSE消息
                            String fullContent = state.accumulatedString + newContent;

                            // 查找完整的SSE消息边界
                            int lastCompleteMessageEnd = findLastCompleteSSEMessage(fullContent);

                            if (lastCompleteMessageEnd > 0) {
                                // 提取完整的SSE消息
                                String completeMessage = fullContent.substring(0, lastCompleteMessageEnd);
                                String remainingString = fullContent.substring(lastCompleteMessageEnd);

                                // 保留未处理的字节
                                byte[] remainingBytes = new byte[combined.length - validEnd];
                                if (remainingBytes.length > 0) {
                                    System.arraycopy(combined, validEnd, remainingBytes, 0, remainingBytes.length);
                                }

                                log.debug("发现完整SSE消息，长度: {}, 剩余字符串: {}, 剩余字节: {}",
                                        completeMessage.length(), remainingString.length(), remainingBytes.length);

                                return new BufferState(remainingBytes, remainingString, completeMessage);
                            } else {
                                // 没有完整消息，继续累积
                                byte[] remainingBytes = new byte[combined.length - validEnd];
                                if (remainingBytes.length > 0) {
                                    System.arraycopy(combined, validEnd, remainingBytes, 0, remainingBytes.length);
                                }
                                return new BufferState(remainingBytes, fullContent, "");
                            }
                        } else {
                            // 没有有效字符，保持原状态
                            return new BufferState(combined, state.accumulatedString, "");
                        }
                    } catch (Exception e) {
                        log.error("处理数据缓冲时发生异常: {}", e.getMessage());
                        DataBufferUtils.release(dataBuffer);
                        return state;
                    }
                })
                .skip(1) // 跳过初始状态
                .filter(state -> !state.completeMessage.isEmpty())
                .map(state -> state.completeMessage)
                .filter(data -> data.contains("id:") || data.contains("event:") || data.contains("data:"))
                .doOnNext(data -> {
                    // 统计处理的数据块
                    int total = totalProcessedBlocks.incrementAndGet();

                    // 检查是否包含乱码字符
                    boolean hasCorruption = data.contains("");
                    if (hasCorruption) {
                        int corrupted = corruptedBlocks.incrementAndGet();
//                        log.warn("检测到UTF-8字符截断，数据块: {}, 总计: {}, 损坏: {}",
//                                data.substring(0, Math.min(50, data.length())), total, corrupted);
                    } else {
                        int fixed = fixedBlocks.incrementAndGet();
//                        log.debug("UTF-8字符完整性检查通过，数据块: {}, 总计: {}, 修复: {}",
//                                data.substring(0, Math.min(50, data.length())), total, fixed);
                    }

                    log.debug("成功处理SSE数据块，长度: {}, 包含id: {}, 包含event: {}, 包含data: {}",
                            data.length(),
                            data.contains("id:"),
                            data.contains("event:"),
                            data.contains("data:"));
                })
                .doOnComplete(() -> {
                    log.info("SSE数据流处理完成");
                    logUtf8ProcessingStats();
                })
                .doOnError(error -> {
                    log.error("SSE数据流处理出错: {}", error.getMessage());
                    logUtf8ProcessingStats();
                });
    }

    /**
     * 缓冲状态类，用于保存累积的字节、字符串和完整消息
     */
    private static class BufferState {
        final byte[] accumulatedBytes;
        final String accumulatedString;
        final String completeMessage;

        public BufferState() {
            this.accumulatedBytes = new byte[0];
            this.accumulatedString = "";
            this.completeMessage = "";
        }

        public BufferState(byte[] accumulatedBytes, String accumulatedString, String completeMessage) {
            this.accumulatedBytes = accumulatedBytes;
            this.accumulatedString = accumulatedString;
            this.completeMessage = completeMessage;
        }
    }

    /**
     * 查找最后一个完整SSE消息的结束位置
     */
    private int findLastCompleteSSEMessage(String content) {
        // 查找所有的双换行符位置（SSE消息分隔符）
        int lastDoubleNewline = content.lastIndexOf("\n\n");
        if (lastDoubleNewline >= 0) {
            return lastDoubleNewline + 2; // 包含双换行符
        }

        // 查找单独的data:行结束
        int lastDataEnd = -1;
        String[] lines = content.split("\n");
        for (int i = 0; i < lines.length; i++) {
            if (lines[i].startsWith("data:")) {
                // 找到data行，检查是否完整（以换行符结束）
                int dataLineStart = content.indexOf(lines[i]);
                int dataLineEnd = dataLineStart + lines[i].length();
                if (dataLineEnd < content.length() && content.charAt(dataLineEnd) == '\n') {
                    lastDataEnd = dataLineEnd + 1;
                }
            }
        }

        return lastDataEnd;
    }

    /**
     * 找到字节数组中最后一个完整UTF-8字符的边界
     * 避免在多字节字符中间截断
     */
    private int findLastValidUtf8Boundary(byte[] bytes) {
        if (bytes.length == 0) return 0;

        // 从末尾向前查找，找到完整的UTF-8字符边界
        for (int i = bytes.length - 1; i >= 0; i--) {
            byte b = bytes[i];

            // ASCII字符 (0xxxxxxx)
            if ((b & 0x80) == 0) {
                return i + 1;
            }

            // UTF-8多字节字符的开始字节
            if ((b & 0xC0) == 0xC0) {
                // 检查这个多字节字符是否完整
                int expectedLength = getUtf8CharLength(b);
                if (i + expectedLength <= bytes.length) {
                    return i + expectedLength;
                } else {
                    // 字符不完整，返回这个字符之前的位置
                    return i;
                }
            }
        }

        return bytes.length;
    }

    /**
     * 根据UTF-8首字节确定字符长度
     */
    private int getUtf8CharLength(byte firstByte) {
        if ((firstByte & 0x80) == 0) return 1;      // 0xxxxxxx
        if ((firstByte & 0xE0) == 0xC0) return 2;  // 110xxxxx
        if ((firstByte & 0xF0) == 0xE0) return 3;  // 1110xxxx
        if ((firstByte & 0xF8) == 0xF0) return 4;  // 11110xxx
        return 1; // 默认值
    }

    /**
     * 输出UTF-8字符处理统计信息
     */
    private void logUtf8ProcessingStats() {
        int total = totalProcessedBlocks.get();
        int corrupted = corruptedBlocks.get();
        int fixed = fixedBlocks.get();

        if (total > 0) {
            double corruptionRate = (double) corrupted / total * 100;
            double fixRate = (double) fixed / total * 100;

            log.info("UTF-8字符处理统计 - 总处理: {}, 损坏: {} ({:.2f}%), 修复: {} ({:.2f}%)",
                    total, corrupted, corruptionRate, fixed, fixRate);
        }
    }
}
